///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// HLOperationLowerExtension.cpp                                             //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "dxc/HLSL/HLOperationLowerExtension.h"

#include "dxc/DXIL/DxilModule.h"
#include "dxc/DXIL/DxilOperations.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/HLOperationLower.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HlslIntrinsicOp.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/YAMLParser.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/ADT/SmallString.h"

using namespace llvm;
using namespace hlsl;

// Raise an hlsl::Exception with DXC_E_EXTENSION_ERROR. Every message is
// prefixed so users can tell the failure came from the extension API.
LLVM_ATTRIBUTE_NORETURN static void ThrowExtensionError(StringRef Details) {
  std::string Message("Error in dxc extension api: ");
  Message.append(Details.data(), Details.size());
  throw hlsl::Exception(DXC_E_EXTENSION_ERROR, Message);
}

// The lowering strategy format is a string that matches the following regex:
//
//      [a-z](:(?P<ExtraStrategyInfo>.+))?$
//
// The first character indicates the strategy with an optional : followed by
// additional lowering information specific to that strategy.
//
ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
  // An empty string cannot name a strategy.
  if (strategy.empty())
    return Strategy::Unknown;

  // Only the leading character selects the strategy; anything after it is
  // strategy-specific extra information handled elsewhere.
  const char code = strategy.front();
  if (code == 'n') return Strategy::NoTranslation;
  if (code == 'r') return Strategy::Replicate;
  if (code == 'p') return Strategy::Pack;
  if (code == 'm') return Strategy::Resource;
  if (code == 'd') return Strategy::Dxil;
  return Strategy::Unknown;
}

llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
  // Inverse of GetStrategy: map a strategy back to its one-letter code.
  // Unrecognized strategies are reported as "?".
  llvm::StringRef code = "?";
  switch (strategy) {
  case Strategy::NoTranslation:
    code = "n";
    break;
  case Strategy::Replicate:
    code = "r";
    break;
  case Strategy::Pack:
    code = "p";
    break;
  case Strategy::Resource:
    code = "m"; // m for resource method
    break;
  case Strategy::Dxil:
    code = "d";
    break;
  default:
    break;
  }
  return code;
}

// Return everything after the first ':' in the strategy string, or an empty
// string when there is no ':' separator.
static std::string ParseExtraStrategyInfo(StringRef strategy)
{
    StringRef extraInfo = strategy.split(":").second;
    return extraInfo.str();
}

// Construct a lowering for an already-decoded strategy. No extra strategy
// info is recorded on this path.
ExtensionLowering::ExtensionLowering(Strategy strategy,
                                     HLSLExtensionsCodegenHelper *helper,
                                     OP &hlslOp,
                                     HLResourceLookup &hlResourceLookup)
    : m_strategy(strategy),
      m_helper(helper),
      m_hlslOp(hlslOp),
      m_hlResourceLookup(hlResourceLookup) {}

// Construct a lowering from a strategy string: the first character picks the
// strategy and any ":<info>" suffix is kept as extra strategy info.
ExtensionLowering::ExtensionLowering(StringRef strategy,
                                     HLSLExtensionsCodegenHelper *helper,
                                     OP &hlslOp,
                                     HLResourceLookup &hlResourceLookup)
    : ExtensionLowering(GetStrategy(strategy), helper, hlslOp, hlResourceLookup) {
  m_extraStrategyInfo = ParseExtraStrategyInfo(strategy);
}

// Lower an HL extension call using the configured strategy. The caller is
// responsible for replacing uses of CI with the returned value.
llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
  switch (m_strategy) {
  case Strategy::NoTranslation:
    return NoTranslation(CI);
  case Strategy::Replicate:
    return Replicate(CI);
  case Strategy::Pack:
    return Pack(CI);
  case Strategy::Resource:
    return Resource(CI);
  case Strategy::Dxil:
    return Dxil(CI);
  default:
    return Unknown(CI);
  }
}

// Fallback for an unrecognized strategy: trips an assert in debug builds and
// reports failure to the caller by returning nullptr.
llvm::Value *ExtensionLowering::Unknown(CallInst *CI) {
  (void)CI;
  assert(!"unknown translation strategy");
  return nullptr;
}

// Abstract policy describing how an HL call's return and argument types map
// onto the types of the lowered dxil function.
class FunctionTypeTranslator {
public:
  // A translated argument becomes `count` consecutive parameters of `type`.
  // For example a <2 x i32> argument becomes { i32, 2 } when the vector is
  // expanded in place, or { i32, 1 } when the call is replicated.
  struct ArgumentType {
    Type *type;
    int count;

    ArgumentType(Type *ty, int cnt = 1) : type(ty), count(cnt) {}
  };

  virtual ~FunctionTypeTranslator() = default;

  // Lowered return type; nullptr means the call cannot be lowered this way.
  virtual Type *TranslateReturnType(CallInst *CI) = 0;
  // Lowered parameter type(s) for a single original argument.
  virtual ArgumentType TranslateArgumentType(Value *OrigArg) = 0;
};

// Creates (or reuses) the declaration of the lowered dxil function for an HL
// call. A FunctionTypeTranslator supplies the return/argument type mapping;
// subclasses may override GetFunctionType to build the parameter list in a
// different way.
class FunctionTranslator {
public:
  // Convenience overload: default-constructs a TypeTranslator and lowers CI.
  template <typename TypeTranslator>
  static Function *GetLoweredFunction(CallInst *CI, ExtensionLowering &lower) {
    TypeTranslator typeTranslator;
    return GetLoweredFunction(typeTranslator, CI, lower);
  }

  // Lower CI with a caller-provided type translator. Returns nullptr when the
  // translator rejects the call's types.
  static Function *GetLoweredFunction(FunctionTypeTranslator &typeTranslator, CallInst *CI, ExtensionLowering &lower) {
    FunctionTranslator translator(typeTranslator, lower);
    return translator.GetLoweredFunction(CI);
  }

  virtual ~FunctionTranslator() {}

protected:
  FunctionTypeTranslator &m_typeTranslator;
  ExtensionLowering &m_lower;

  FunctionTranslator(FunctionTypeTranslator &typeTranslator, ExtensionLowering &lower)
    : m_typeTranslator(typeTranslator)
    , m_lower(lower)
  {}

  // Build the lowered function type and get-or-insert a declaration named by
  // ExtensionLowering::GetExtensionName. Returns nullptr if either the return
  // type or the function type cannot be translated.
  Function *GetLoweredFunction(CallInst *CI) {
    // Get the return type of the lowered function.
    Type *RetTy = m_typeTranslator.TranslateReturnType(CI);
    if (!RetTy)
      return nullptr;

    // Get the function type of the lowered function.
    FunctionType *FTy = GetFunctionType(CI, RetTy);
    if (!FTy)
      return nullptr;

    // Create (or reuse) the declaration of the lowered function.
    AttributeSet attributes = GetAttributeSet(CI);
    std::string name = m_lower.GetExtensionName(CI);
    // NOTE(review): getOrInsertFunction returns a bitcast constant rather than
    // a Function when `name` is already declared with a different type, and
    // cast<> asserts in that case. Consider reporting such a conflict
    // explicitly (e.g. via ThrowExtensionError).
    return cast<Function>(CI->getModule()->getOrInsertFunction(name, FTy, attributes));
  }

  // Default parameter list: each original argument expands to `count`
  // consecutive parameters of its translated type.
  virtual FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) {
    SmallVector<Type *, 10> ParamTypes;
    ParamTypes.reserve(CI->getNumArgOperands());
    for (unsigned argIdx = 0; argIdx < CI->getNumArgOperands(); ++argIdx) {
      Value *OrigArg = CI->getArgOperand(argIdx);
      FunctionTypeTranslator::ArgumentType newArgType = m_typeTranslator.TranslateArgumentType(OrigArg);
      // Use a distinct counter; the previous inner loop shadowed the outer index.
      for (int copy = 0; copy < newArgType.count; ++copy) {
        ParamTypes.push_back(newArgType.type);
      }
    }

    const bool IsVarArg = false;
    return FunctionType::get(RetTy, ParamTypes, IsVarArg);
  }

  // Copy a fixed set of memory/unwind function attributes from the called HL
  // function onto the lowered declaration. Assumes CI is a direct call
  // (getCalledFunction() is dereferenced without a null check).
  AttributeSet GetAttributeSet(CallInst *CI) {
    Function *F = CI->getCalledFunction();
    AttributeSet attributes;
    auto copyAttribute = [=, &attributes](Attribute::AttrKind a) {
      if (F->hasFnAttribute(a)) {
        attributes = attributes.addAttribute(CI->getContext(), AttributeSet::FunctionIndex, a);
      }
    };
    copyAttribute(Attribute::AttrKind::ReadOnly);
    copyAttribute(Attribute::AttrKind::ReadNone);
    copyAttribute(Attribute::AttrKind::ArgMemOnly);
    copyAttribute(Attribute::AttrKind::NoUnwind);

    return attributes;
  }
};

///////////////////////////////////////////////////////////////////////////////
// NoTranslation Lowering.
// Identity type mapping: the lowered function keeps the HL call's return type
// and every argument type exactly as-is.
class NoTranslationTypeTranslator : public FunctionTypeTranslator {
  virtual Type *TranslateReturnType(CallInst *CI) override {
    Type *unchanged = CI->getType();
    return unchanged;
  }
  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
    return ArgumentType(OrigArg->getType(), 1);
  }
};

// Re-issue the HL call, unchanged, against the lowered extension declaration.
// Returns nullptr if the declaration could not be created.
llvm::Value *ExtensionLowering::NoTranslation(CallInst *CI) {
  Function *target =
      FunctionTranslator::GetLoweredFunction<NoTranslationTypeTranslator>(CI, *this);
  if (target == nullptr)
    return nullptr;

  // Forward the original operands one-to-one.
  SmallVector<Value *, 8> operands;
  for (Value *operand : CI->arg_operands())
    operands.push_back(operand);

  IRBuilder<> builder(CI);
  return builder.CreateCall(target, operands);
}

///////////////////////////////////////////////////////////////////////////////
// Replicated Lowering.
// Sentinel meaning "this call cannot be replicated lane-by-lane".
enum {
  NO_COMMON_VECTOR_SIZE = 0,
};

// Determine how many times a call must be replicated: the element count
// shared by a vector return type and every vector argument. Scalar operands
// do not constrain the result. Returns NO_COMMON_VECTOR_SIZE when no vectors
// are involved or when two vector widths disagree.
static unsigned GetReplicatedVectorSize(llvm::CallInst *CI) {
  Type *RetTy = CI->getType();
  unsigned width = RetTy->isVectorTy()
                       ? RetTy->getVectorNumElements()
                       : static_cast<unsigned>(NO_COMMON_VECTOR_SIZE);

  const unsigned numArgs = CI->getNumArgOperands();
  for (unsigned argIdx = 0; argIdx != numArgs; ++argIdx) {
    Type *ArgTy = CI->getArgOperand(argIdx)->getType();
    if (!ArgTy->isVectorTy())
      continue; // scalars are shared by every replicated call

    const unsigned argWidth = ArgTy->getVectorNumElements();
    const bool haveWidth = width != NO_COMMON_VECTOR_SIZE;
    if (haveWidth && width != argWidth)
      return NO_COMMON_VECTOR_SIZE; // lanes cannot be paired up
    width = argWidth;
  }
  return width;
}

// Type mapping for the Replicate strategy: the lowered function handles one
// vector lane at a time.
class ReplicatedFunctionTypeTranslator : public FunctionTypeTranslator {
  // The result is the element type of the vector result, or void. Returns
  // nullptr when the call has no common vector width or returns a non-vector,
  // non-void value.
  virtual Type *TranslateReturnType(CallInst *CI) override {
    if (GetReplicatedVectorSize(CI) == NO_COMMON_VECTOR_SIZE)
      return nullptr;

    Type *RetTy = CI->getType();
    if (RetTy->isVoidTy())
      return RetTy;
    if (RetTy->isVectorTy())
      return RetTy->getVectorElementType();

    return nullptr;
  }

  // Vector arguments shrink to their element type; scalars pass through.
  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
    Type *Ty = OrigArg->getType();
    return ArgumentType(Ty->isVectorTy() ? Ty->getVectorElementType() : Ty);
  }
};

// Emits the calls for the Replicate strategy: one call to the scalarized
// function per vector lane, then (for non-void calls) rebuilds a vector of the
// original return type from the per-lane results.
class ReplicateCall {
public:
  ReplicateCall(CallInst *CI, Function &ReplicatedFunction)
    : m_CI(CI)
    , m_ReplicatedFunction(ReplicatedFunction)
    , m_numReplicatedCalls(GetReplicatedVectorSize(CI))
    , m_ScalarizeArgIdx()
    , m_Args(CI->getNumArgOperands())
    // Relies on m_numReplicatedCalls being declared (and therefore
    // initialized) before m_ReplicatedCalls.
    , m_ReplicatedCalls(m_numReplicatedCalls)
    , m_Builder(CI)
  {
    assert(m_numReplicatedCalls != NO_COMMON_VECTOR_SIZE);
  }

  // Emit the replicated calls and return the value that replaces m_CI.
  Value *Generate() {
    CollectReplicatedArguments();
    CreateReplicatedCalls();
    Value *retVal = GetReturnValue();
    return retVal;
  }

private:
  CallInst *m_CI;
  Function &m_ReplicatedFunction;
  unsigned m_numReplicatedCalls;
  SmallVector<unsigned, 10> m_ScalarizeArgIdx;
  SmallVector<Value *, 10> m_Args;
  SmallVector<Value *, 10> m_ReplicatedCalls;
  IRBuilder<> m_Builder;

  // Collect replicated arguments.
  // Non-vector arguments go straight into the argument list and are shared by
  // every replicated call. For vector arguments we only remember their
  // position; the per-lane element is filled in when each call is created.
  void CollectReplicatedArguments() {
    for (unsigned i = 0; i < m_CI->getNumArgOperands(); ++i) {
      Type *Ty = m_CI->getArgOperand(i)->getType();
      if (Ty->isVectorTy()) {
        m_ScalarizeArgIdx.push_back(i);
      }
      else {
        m_Args[i] = m_CI->getArgOperand(i);
      }
    }
  }

  // Create replicated calls.
  // Replicate the call once for each element of the replicated vector size,
  // extracting lane `vecIdx` from every vector argument.
  void CreateReplicatedCalls() {
    for (unsigned vecIdx = 0; vecIdx < m_numReplicatedCalls; vecIdx++) {
      for (unsigned i = 0, e = m_ScalarizeArgIdx.size(); i < e; ++i) {
        unsigned argIdx = m_ScalarizeArgIdx[i];
        Value *arg = m_CI->getArgOperand(argIdx);
        m_Args[argIdx] = m_Builder.CreateExtractElement(arg, vecIdx);
      }
      Value *EltOP = m_Builder.CreateCall(&m_ReplicatedFunction, m_Args);
      m_ReplicatedCalls[vecIdx] = EltOP;
    }
  }

  // Get the final replicated value.
  // If the function returns void, return (arbitrarily) the LAST replicated
  // call. We do not return nullptr because that indicates a failure to
  // replicate. If the function returns a vector, aggregate all of the
  // replicated call values into a new vector.
  Value *GetReturnValue() {
    if (m_CI->getType()->isVoidTy())
      return m_ReplicatedCalls.back();

    Value *retVal = llvm::UndefValue::get(m_CI->getType());
    for (unsigned i = 0; i < m_ReplicatedCalls.size(); ++i)
      retVal = m_Builder.CreateInsertElement(retVal, m_ReplicatedCalls[i], i);

    return retVal;
  }
};

// Translate the HL call by replicating the call for each vector element.
//
// For example,
//
//    <2xi32> %r = call @ext.foo(i32 %op, <2xi32> %v)
//    ==>
//    %r.1 = call @ext.foo.s(i32 %op, i32 %v.1)
//    %r.2 = call @ext.foo.s(i32 %op, i32 %v.2)
//    <2xi32> %r.v.1 = insertelement %r.1, 0, <2xi32> undef
//    <2xi32> %r.v.2 = insertelement %r.2, 1, %r.v.1
//
// You can then RAUW (replaceAllUsesWith) %r with %r.v.2. The RAUW is not done by the translate function.
Value *ExtensionLowering::Replicate(CallInst *CI) {
  Function *ReplicatedFunction =
      FunctionTranslator::GetLoweredFunction<ReplicatedFunctionTypeTranslator>(CI, *this);
  // Calls that cannot be replicated are forwarded unchanged instead.
  if (ReplicatedFunction == nullptr)
    return NoTranslation(CI);

  return ReplicateCall(CI, *ReplicatedFunction).Generate();
}

///////////////////////////////////////////////////////////////////////////////
// Packed Lowering.
// Emits the call for the Pack strategy: each vector argument is packed into a
// literal struct with one field per lane, the packed function is called, and
// a struct result is unpacked back into a vector.
class PackCall {
public:
  PackCall(CallInst *CI, Function &PackedFunction)
    : m_CI(CI), m_packedFunction(PackedFunction), m_builder(CI) {}

  // Emit the packed call and return the value that replaces m_CI.
  Value *Generate() {
    SmallVector<Value *, 10> packedArgs;
    for (Value *arg : m_CI->arg_operands()) {
      Value *packed = arg->getType()->isVectorTy()
                          ? PackVectorIntoStruct(m_builder, arg)
                          : arg;
      packedArgs.push_back(packed);
    }

    Value *result = m_builder.CreateCall(&m_packedFunction, packedArgs);
    if (!result->getType()->isStructTy())
      return result;
    return PackStructIntoVector(m_builder, result);
  }

  // <N x ty>  ->  { ty, ty, ..., ty } with N fields.
  static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
    assert(vecTy->isVectorTy());
    SmallVector<Type *, 4> fields(vecTy->getVectorNumElements(),
                                  vecTy->getVectorElementType());
    return StructType::get(vecTy->getContext(), fields);
  }

private:
  CallInst *m_CI;
  Function &m_packedFunction;
  IRBuilder<> m_builder;

  // { ty, ... } (N fields)  ->  <N x ty>, using the type of field 0.
  static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
    assert(structTy->isStructTy());
    Type *elementTy = structTy->getStructElementType(0);
    return VectorType::get(elementTy, structTy->getStructNumElements());
  }

  // Copy each lane of `vec` into the matching field of a new struct value.
  static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
    StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
    Value *aggregate = UndefValue::get(structTy);
    for (unsigned lane = 0, e = structTy->getNumElements(); lane != e; ++lane) {
      Value *laneValue = builder.CreateExtractElement(vec, lane);
      aggregate = builder.CreateInsertValue(aggregate, laneValue, { lane });
    }
    return aggregate;
  }

  // Copy each field of `strukt` into the matching lane of a new vector value.
  static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
    VectorType *vecTy = ConvertStructTypeToVectorType(strukt->getType());
    Value *vec = UndefValue::get(vecTy);
    for (unsigned lane = 0, e = vecTy->getNumElements(); lane != e; ++lane) {
      Value *field = builder.CreateExtractValue(strukt, lane);
      vec = builder.CreateInsertElement(vec, field, lane);
    }
    return vec;
  }
};

// Type mapping for the Pack strategy: every vector type (return or argument)
// is replaced by the equivalent literal struct; other types are unchanged.
class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
  virtual Type *TranslateReturnType(CallInst *CI) override {
    return PackIfVector(CI->getType());
  }
  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
    return ArgumentType(PackIfVector(OrigArg->getType()));
  }

  static Type *PackIfVector(Type *ty) {
    if (!ty->isVectorTy())
      return ty;
    return PackCall::ConvertVectorTypeToStructType(ty);
  }
};

// Lower CI with the Pack strategy, falling back to an unchanged call when no
// packed declaration can be created.
Value *ExtensionLowering::Pack(CallInst *CI) {
  Function *PackedFunction =
      FunctionTranslator::GetLoweredFunction<PackedFunctionTypeTranslator>(CI, *this);
  if (PackedFunction == nullptr)
    return NoTranslation(CI);

  return PackCall(CI, *PackedFunction).Generate();
}

///////////////////////////////////////////////////////////////////////////////
// Resource Lowering.

// Modify a call to a resource method. Makes the following transformation:
//
// 1. Convert non-void return value to dx.types.ResRet.
// 2. Expand vectors in place as separate arguments.
//
// Example
// -----------------------------------------------------------------------------
//
//  %0 = call <2 x float> MyBufferOp(i32 138, %class.Buffer %3, <2 x i32> <1 , 2> )
//  %r = call %dx.types.ResRet.f32 MyBufferOp(i32 138, %dx.types.Handle %buf, i32 1, i32 2 )
//  %x = extractvalue %r, 0
//  %y = extractvalue %r, 1
//  %v = <2 x float> undef
//  %v.1 = insertelement %v,   %x, 0
//  %v.2 = insertelement %v.1, %y, 1
class ResourceMethodCall {
public:
  ResourceMethodCall(CallInst *CI) : m_CI(CI), m_builder(CI) {}

  virtual ~ResourceMethodCall() {}

  // Emit a call to `explodedFunction` with the vector operands flattened, then
  // convert its result back to the original HL return type.
  virtual Value *Generate(Function *explodedFunction) {
    SmallVector<Value *, 16> args;
    ExplodeArgs(args);
    return ConvertResult(CreateCall(explodedFunction, args));
  }

protected:
  CallInst *m_CI;
  IRBuilder<> m_builder;

  // Flatten the HL operands: each <N x ty> operand contributes N scalar
  // operands (one extractelement per lane); any other operand is forwarded.
  void ExplodeArgs(SmallVectorImpl<Value*> &args) {
    for (Value *arg : m_CI->arg_operands()) {
      Type *argTy = arg->getType();
      if (!argTy->isVectorTy()) {
        args.push_back(arg);
        continue;
      }
      for (unsigned lane = 0, e = argTy->getVectorNumElements(); lane < e; ++lane)
        args.push_back(m_builder.CreateExtractElement(arg, lane));
    }
  }

  Value *CreateCall(Function *explodedFunction, ArrayRef<Value*> args) {
    return m_builder.CreateCall(explodedFunction, args);
  }

  // Dispatch on the original return type: void, vector, or scalar.
  Value *ConvertResult(Value *result) {
    Type *origRetTy = m_CI->getType();
    if (origRetTy->isVoidTy())
      return ConvertVoidResult(result);
    if (origRetTy->isVectorTy())
      return ConvertVectorResult(origRetTy, result);
    return ConvertScalarResult(origRetTy, result);
  }

  // A void call needs no conversion.
  Value *ConvertVoidResult(Value *result) {
    return result;
  }

  // Rebuild the vector result from the leading fields of the returned struct.
  Value *ConvertVectorResult(Type *origRetTy, Value *result) {
    Type *resourceRetTy = result->getType();
    assert(origRetTy->isVectorTy());
    assert(resourceRetTy->isStructTy() && "expected resource return type to be a struct");

    const unsigned vectorSize = origRetTy->getVectorNumElements();
    const unsigned structSize = resourceRetTy->getStructNumElements();
    // The struct is expected to have strictly more fields than the vector has
    // lanes (presumably a trailing status field in dx.types.ResRet -- confirm
    // against OP::GetResRetType); only the overlapping prefix is copied.
    const unsigned copyCount = std::min(vectorSize, structSize);
    assert(vectorSize < structSize);

    Value *vector = UndefValue::get(origRetTy);
    for (unsigned lane = 0; lane < copyCount; ++lane) {
      Value *field = m_builder.CreateExtractValue(result, { lane });
      vector = m_builder.CreateInsertElement(vector, field, lane);
    }
    return vector;
  }

  // A scalar result is taken from field 0 of the returned struct.
  Value *ConvertScalarResult(Type *origRetTy, Value *result) {
    assert(origRetTy->isSingleValueType());
    return m_builder.CreateExtractValue(result, { 0 });
  }

};

// Type mapping for the default Resource strategy.
class ResourceFunctionTypeTranslator : public FunctionTypeTranslator {
public:
  ResourceFunctionTypeTranslator(OP &hlslOp) : m_hlslOp(hlslOp) {}

  // Return type mapping:
  //   void     -> void
  //   <N x ty> -> dx.types.ResRet.ty
  //    ty      -> dx.types.ResRet.ty
  virtual Type *TranslateReturnType(CallInst *CI) override {
    Type *RetTy = CI->getType();
    if (RetTy->isVoidTy())
      return RetTy;

    Type *elementTy = RetTy->isVectorTy() ? RetTy->getVectorElementType() : RetTy;
    return m_hlslOp.GetResRetType(elementTy);
  }

  // Argument type mapping:
  //   <N x ty> -> { ty, N }  (N consecutive parameters of type ty)
  //    ty      -> { ty, 1 }
  // Non-vector operands, including the resource handle operand, keep their
  // type unchanged; this translator does not itself convert resource types
  // (presumably the handle is already lowered by this point -- confirm).
  virtual ArgumentType TranslateArgumentType(Value *OrigArg) override {
    Type *ty = OrigArg->getType();
    if (!ty->isVectorTy())
      return ArgumentType(ty);
    return ArgumentType(ty->getVectorElementType(), ty->getVectorNumElements());
  }

private:
  OP &m_hlslOp;
};

// Lower a resource-method call. A ":<info>" suffix on the strategy string
// (the extra strategy info) selects the custom lowering path instead.
Value *ExtensionLowering::Resource(CallInst *CI) {
  if (!m_extraStrategyInfo.empty())
    return CustomResource(CI);

  ResourceFunctionTypeTranslator resourceTypeTranslator(m_hlslOp);
  Function *resourceFunction =
      FunctionTranslator::GetLoweredFunction(resourceTypeTranslator, CI, *this);
  if (resourceFunction == nullptr)
    return NoTranslation(CI);

  return ResourceMethodCall(CI).Generate(resourceFunction);
}

// This class handles the core logic for custom lowering of resource
// method intrinsics. The goal is to allow resource extension intrinsics
// to be handled the same way as the core hlsl resource intrinsics.
//
// Specifically, we want to support:
//
//  1. Multiple hlsl overloads map to a single dxil intrinsic
//  2. The hlsl overloads can take different parameters for a given resource type
//  3. The hlsl overloads are not consistent across different resource types 
//
// To achieve these goals we need a more complex mechanism for describing how
// to translate the high-level arguments to arguments for a dxil function.
// The custom lowering info describes this lowering using the following format.
//
// [Custom Lowering Info Format]
// A json string encoding a map where each key is either a specific resource type or
// the keyword "default" to be used for any other resource. The value is a
// a custom-format string encoding how high-level arguments are mapped to
// dxil intrinsic arguments.
//
// [Argument Translation Format]
// A comma separated string where the number of fields is exactly equal to the number
// of parameters in the target dxil intrinsic. Each field describes how to generate
// the argument for that dxil intrinsic parameter. It has the following format where
// the hl_arg_index is mandatory, but the other two parts are optional.
//
//      <hl_arg_index>.<vector_index>:<optional_type_info>
//
// The format is precisely described by the following regular expression:
//
//      (?P<hl_arg_index>[-0-9]+)(\.(?P<vector_index>[-0-9]+))?(:(?P<optional_type_info>\?i32|\?i16|\?i8|\?float|\?half))?$
//
// Example
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Say we want to define the MyTextureOp extension with the following overloads:
//
// Texture1D
//  MyTextureOp(uint addr, uint offset)
//  MyTextureOp(uint addr, uint offset, uint val)
//
// Texture2D
//  MyTextureOp(uint2 addr, uint2 val)
//  
// And a dxil intrinsic defined as follows
//  @MyTextureOp(i32 opcode,  %dx.types.Handle handle, i32 addr0, i32 addr1, i32 offset, i32 val0, i32 val1)
//
// Then we would define the lowering info json as follows
//
//  {
//      "default"   : "0, 1, 2.0, 2.1,  3     , 4.0:?i32, 4.1:?i32",
//      "Texture2D" : "0, 1, 2.0, 2.1, -1:?i32, 3.0     , 3.1"
//  }
//
//
//  This would produce the following lowerings (assuming the MyTextureOp opcode is 17)
//
//  hlsl: Texture1D.MyTextureOp(a, b)
//  hl:   @MyTextureOp(17, handle, a, b)
//  dxil: @MyTextureOp(17, handle, a, undef, b, undef, undef)
//
//  hlsl: Texture1D.MyTextureOp(a, b, c)
//  hl:   @MyTextureOp(17, handle, a, b, c)
//  dxil: @MyTextureOp(17, handle, a, undef, b, c, undef)
//
//  hlsl: Texture2D.MyTextureOp(a, c)
//  hl:   @MyTextureOp(17, handle, a, c)
//  dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
//
// 
class CustomResourceLowering
{
public:
    CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
    {
        // Parse lowering info json format.
        std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
            ParseLoweringInfo(LoweringInfo, CI->getContext());

        // Lookup resource kind based on handle (first arg after hl opcode)
        enum {RESOURCE_HANDLE_ARG=1};
        const char *pName = nullptr;
        if (!ResourceLookup.GetResourceKindName(CI->getArgOperand(RESOURCE_HANDLE_ARG), &pName))
        {
            ThrowExtensionError("Failed to find resource from handle");
        }
        std::string Name(pName);

        // Select lowering info to use based on resource kind.
        const char *DefaultInfoName = "default";
        std::vector<DxilArgInfo> *pArgInfo = nullptr;
        if (LoweringInfoMap.count(Name))
        {
            pArgInfo = &LoweringInfoMap.at(Name);
        }
        else if (LoweringInfoMap.count(DefaultInfoName))
        {
            pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
        }
        else
        {
            ThrowExtensionError("Unable to find lowering info for resource");
        }
        GenerateLoweredArgs(CI, *pArgInfo);
    }

    const std::vector<Value *> &GetLoweredArgs() const
    {
        return m_LoweredArgs;
    }

private:
    struct OptionalTypeSpec
    {
        const char* TypeName;
        Type *LLVMType;
    };

    // These are the supported optional types for generating dxil parameters
    // that have no matching argument in the high-level intrinsic overload.
    // See [Argument Translation Format] for details.
    void InitOptionalTypes(LLVMContext &Ctx)
    {
        // Table of supported optional types.
        // Keep in sync with m_OptionalTypes small vector size to avoid
        // dynamic allocation.
        OptionalTypeSpec OptionalTypes[] = {
            {"?i32",   Type::getInt32Ty(Ctx)},
            {"?float", Type::getFloatTy(Ctx)},
            {"?half",  Type::getHalfTy(Ctx)},
            {"?i8",    Type::getInt8Ty(Ctx)},
            {"?i16",   Type::getInt16Ty(Ctx)},
        };
        DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
        m_OptionalTypes.clear();
        m_OptionalTypes.reserve(_countof(OptionalTypes));

        for (const OptionalTypeSpec &T : OptionalTypes)
        {
            m_OptionalTypes.push_back(T);
        }
    }

    Type *ParseOptionalType(StringRef OptionalTypeInfo)
    {
        if (OptionalTypeInfo.empty())
        {
            return nullptr;
        }

        for (OptionalTypeSpec &O : m_OptionalTypes)
        {
            if (OptionalTypeInfo == O.TypeName)
            {
                return O.LLVMType;
            }
        }
            
        ThrowExtensionError("Failed to parse optional type");
    }
    
    // Mapping from high level function arg to dxil function arg.
    //
    // The `HighLevelArgIndex` is the index of the function argument to
    // which this dxil argument maps.
    //
    // If `HasVectorIndex` is true then the `VectorIndex` contains the
    // index of the element in the vector pointed to by HighLevelArgIndex.
    //
    // The `OptionalType` is used to specify types for arguments that are not
    // present in all overloads of the high level function. This lets us
    // map multiple high level functions to a single dxil extension intrinsic.
    //
    struct DxilArgInfo
    {
        unsigned HighLevelArgIndex = 0;
        unsigned VectorIndex = 0;
        bool HasVectorIndex = false;
        Type *OptionalType = nullptr;
    };
    typedef std::string ResourceKindName;

    // Convert the lowering info to a machine-friendly format.
    // Note that we use the YAML parser to parse the JSON since JSON
    // is a subset of YAML (and this llvm has no JSON parser).
    //
    // See [Custom Lowering Info Format] for details.
    std::map<ResourceKindName, std::vector<DxilArgInfo>> ParseLoweringInfo(StringRef LoweringInfo, LLVMContext &Ctx)
    {
        InitOptionalTypes(Ctx);
        std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap;

        SourceMgr SM;
        yaml::Stream YAMLStream(LoweringInfo, SM);

        // Make sure we have a valid json input.
        llvm::yaml::document_iterator I = YAMLStream.begin();
        if (I == YAMLStream.end()) {
            ThrowExtensionError("Found empty resource lowering JSON.");
        }
        llvm::yaml::Node *Root = I->getRoot();
        if (!Root) {
            ThrowExtensionError("Error parsing resource lowering JSON.");
        }

        // Parse the top level map object.
        llvm::yaml::MappingNode *Object = dyn_cast<llvm::yaml::MappingNode>(Root);
        if (!Object) {
            ThrowExtensionError("Expected map in top level of resource lowering JSON.");
        }

        // Parse all key/value pairs from the map.
        for (llvm::yaml::MappingNode::iterator KVI = Object->begin(),
            KVE = Object->end();
            KVI != KVE; ++KVI) 
        {
            // Parse key.
            llvm::yaml::ScalarNode *KeyString =
                dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getKey());
            if (!KeyString) {
                ThrowExtensionError("Expected string as key in resource lowering info JSON map.");
            }
            SmallString<32> KeyStorage;
            StringRef Key = KeyString->getValue(KeyStorage);

            // Parse value.
            llvm::yaml::ScalarNode *ValueString =
                dyn_cast_or_null<llvm::yaml::ScalarNode>((*KVI).getValue());
            if (!ValueString) {
                ThrowExtensionError("Expected string as value in resource lowering info JSON map.");
            }
            SmallString<128> ValueStorage;
            StringRef Value = ValueString->getValue(ValueStorage);

            // Parse dxil arg info from value.
            LoweringInfoMap[Key] = ParseDxilArgInfo(Value, Ctx);
        }

        return LoweringInfoMap;
    }


    // Parse the dxail argument translation info.
    // See [Argument Translation Format] for details.
    std::vector<DxilArgInfo> ParseDxilArgInfo(StringRef ArgSpec, LLVMContext &Ctx)
    {
        std::vector<DxilArgInfo> Args;

        SmallVector<StringRef, 14> Splits;
        ArgSpec.split(Splits, ",");

        for (const StringRef Split : Splits)
        {
            StringRef Field = Split.trim();
            StringRef HighLevelArgInfo;
            StringRef OptionalTypeInfo;
            std::tie(HighLevelArgInfo, OptionalTypeInfo) = Field.split(":");

            Type *OptionalType = ParseOptionalType(OptionalTypeInfo);

            StringRef HighLevelArgIndex;
            StringRef VectorIndex;
            std::tie(HighLevelArgIndex, VectorIndex) = HighLevelArgInfo.split(".");

            // Parse the arg and vector index.
            // Parse the values as signed integers, but store them as unsigned values to
            // allows using -1 as a shorthand for the max value.
            DxilArgInfo ArgInfo;
            ArgInfo.HighLevelArgIndex = static_cast<unsigned>(std::stoi(HighLevelArgIndex));
            if (!VectorIndex.empty())
            {
                ArgInfo.HasVectorIndex = true;
                ArgInfo.VectorIndex = static_cast<unsigned>(std::stoi(VectorIndex));
            }
            ArgInfo.OptionalType = OptionalType;

            Args.push_back(ArgInfo);
        }

        return Args;
    }

    // Create the dxil args based on custom lowering info.
    void GenerateLoweredArgs(CallInst *CI, const std::vector<DxilArgInfo> &ArgInfoRecords)
    {
        IRBuilder<> builder(CI);
        for (const DxilArgInfo &ArgInfo : ArgInfoRecords)
        {
            // Check to see if we have the corresponding high-level arg in the overload for this call.
            if (ArgInfo.HighLevelArgIndex < CI->getNumArgOperands())
            {
                Value *Arg = CI->getArgOperand(ArgInfo.HighLevelArgIndex);
                if (ArgInfo.HasVectorIndex)
                {
                    // We expect a vector type here, but we handle one special case if not.
                    if (Arg->getType()->isVectorTy())
                    {
                        // We allow multiple high-level overloads to map to a single dxil extension function.
                        // If the vector index is invalid for this specific overload then use an undef
                        // value as a replacement.
                        if (ArgInfo.VectorIndex < Arg->getType()->getVectorNumElements())
                        {
                            Arg = builder.CreateExtractElement(Arg, ArgInfo.VectorIndex);
                        }
                        else
                        {
                            Arg = UndefValue::get(Arg->getType()->getVectorElementType());
                        }
                    }
                    else
                    {
                        // If it is a non-vector type then we replace non-zero vector index with
                        // undef. This is to handle hlsl intrinsic overloading rules that allow
                        // scalars in place of single-element vectors. We assume here that a non-vector
                        // means that a single element vector was already scalarized.
                        // 
                        if (ArgInfo.VectorIndex > 0)
                        {
                            Arg = UndefValue::get(Arg->getType());
                        }
                    }
                }

                m_LoweredArgs.push_back(Arg);
            }
            else if (ArgInfo.OptionalType)
            {
                // If there was no matching high-level arg then we look for the optional
                // arg type specified by the lowering info.
                m_LoweredArgs.push_back(UndefValue::get(ArgInfo.OptionalType));
            }
            else
            { 
                // No way to know how to generate the correc type for this dxil arg.
                ThrowExtensionError("Unable to map high-level arg to dxil arg");
            }
        }
    }
    
    std::vector<Value *> m_LoweredArgs;
    SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
};

// Thin adapter that reuses the generic FunctionTranslator machinery for custom
// resource lowering. The only customization is GetFunctionType: the lowered
// signature is derived from the already-lowered argument list held by a
// CustomResourceLowering object instead of from the original call.
class CustomResourceFunctionTranslator : public FunctionTranslator {
public:
  // Creates a one-shot translator bound to CustomLowering and returns the
  // function produced by FunctionTranslator::GetLoweredFunction for CI.
  static Function *GetLoweredFunction(
      const CustomResourceLowering &CustomLowering,
      ResourceFunctionTypeTranslator &typeTranslator,
      CallInst *CI,
      ExtensionLowering &lower) {
    CustomResourceFunctionTranslator Translator(CustomLowering, typeTranslator, lower);
    // Qualified call: the static method above hides the base-class member.
    return Translator.FunctionTranslator::GetLoweredFunction(CI);
  }

private:
  // Only reachable through the static GetLoweredFunction entry point.
  CustomResourceFunctionTranslator(
      const CustomResourceLowering &CustomLowering,
      ResourceFunctionTypeTranslator &typeTranslator,
      ExtensionLowering &lower)
      : FunctionTranslator(typeTranslator, lower), m_CustomLowering(CustomLowering) {}

  // The lowered signature is RetTy(<type of each lowered arg>...) and is never
  // variadic. CI is unused; the lowered args already encode everything needed.
  FunctionType *GetFunctionType(CallInst *CI, Type *RetTy) override {
    SmallVector<Type *, 16> LoweredParamTypes;
    for (Value *LoweredArg : m_CustomLowering.GetLoweredArgs())
      LoweredParamTypes.push_back(LoweredArg->getType());
    return FunctionType::get(RetTy, LoweredParamTypes, /*isVarArg=*/false);
  }

  // Borrowed; must outlive this translator (it is a stack local in the caller).
  const CustomResourceLowering &m_CustomLowering;
};

// Adapter over ResourceMethodCall whose Generate emits the lowered call using
// the argument list precomputed by CustomResourceLowering rather than the
// operands of the original call.
class CustomResourceMethodCall : public ResourceMethodCall {
public:
  CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
      : ResourceMethodCall(CI), m_CustomLowering(CustomLowering) {}

  // Calls loweredFunction with the precomputed args, then returns the value
  // produced by ConvertResult for that call.
  Value *Generate(Function *loweredFunction) override {
    Value *LoweredCall = CreateCall(loweredFunction, m_CustomLowering.GetLoweredArgs());
    return ConvertResult(LoweredCall);
  }

private:
  // Borrowed; owned by the caller for the duration of Generate.
  const CustomResourceLowering &m_CustomLowering;
};

// Lower a resource intrinsic using the custom lowering description carried in
// m_extraStrategyInfo. Falls back to NoTranslation when no lowered function
// could be produced.
Value *ExtensionLowering::CustomResource(CallInst *CI) {
    // The lowered argument list is computed once here; both the function type
    // and the emitted call are derived from it.
    CustomResourceLowering Lowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
    ResourceFunctionTypeTranslator TypeTranslator(m_hlslOp);

    Function *LoweredFn = CustomResourceFunctionTranslator::GetLoweredFunction(
        Lowering, TypeTranslator, CI, *this);
    if (LoweredFn == nullptr)
        return NoTranslation(CI);

    CustomResourceMethodCall Emitter(CI, Lowering);
    return Emitter.Generate(LoweredFn);
}

///////////////////////////////////////////////////////////////////////////////
// Dxil Lowering.

// Lower an extension intrinsic directly to a dxil operation. Returns nullptr
// when the helper reports no dxil opcode mapping for the high-level opcode.
Value *ExtensionLowering::Dxil(CallInst *CI) {
  // NOTE(review): m_helper is dereferenced without a null check here, unlike in
  // ExtensionName::Get. Presumably the Dxil strategy is only selected when a
  // helper exists; confirm.
  const unsigned hlOpcode = GetHLOpcode(CI);
  OP::OpCode dxilOpcode;
  if (!m_helper->GetDxilOpcode(hlOpcode, dxilOpcode))
    return nullptr;

  // The dxil function is selected by the scalar form of the overload type.
  Type *overloadTy = OP::GetOverloadType(dxilOpcode, CI->getCalledFunction());
  Function *dxilFunc = m_hlslOp.GetOpFunc(dxilOpcode, overloadTy->getScalarType());

  // Rewrite the opcode operand in place so the remaining operands can be
  // forwarded untouched; the original call is deleted afterwards anyway.
  CI->setOperand(0, m_hlslOp.GetI32Const(static_cast<unsigned>(dxilOpcode)));

  // Vector overloads are split into one scalar dxil call per element.
  if (overloadTy->isVectorTy()) {
    ReplicateCall replicate(CI, *dxilFunc);
    return replicate.Generate();
  }

  // Scalar overload: a single call with the same (patched) operand list.
  IRBuilder<> builder(CI);
  SmallVector<Value *, 8> forwardedArgs(CI->arg_operands().begin(), CI->arg_operands().end());
  return builder.CreateCall(dxilFunc, forwardedArgs);
}

///////////////////////////////////////////////////////////////////////////////
// Computing Extension Names.

// Compute the name to use for the intrinsic function call once it is lowered to dxil.
// First checks to see if we have a custom name from the codegen helper and if not
// chooses a default name based on the lowergin strategy.
class ExtensionName {
public:
  ExtensionName(CallInst *CI, ExtensionLowering::Strategy strategy, HLSLExtensionsCodegenHelper *helper)
    : m_CI(CI)
    , m_strategy(strategy)
    , m_helper(helper)
  {}

  std::string Get() {
    std::string name;
    if (m_helper)
      name = GetCustomExtensionName(m_CI, *m_helper);

    if (!HasCustomExtensionName(name))
      name = GetDefaultCustomExtensionName(m_CI, ExtensionLowering::GetStrategyName(m_strategy));

    return name;
  }

private:
  CallInst *m_CI;
  ExtensionLowering::Strategy m_strategy;
  HLSLExtensionsCodegenHelper *m_helper;

  static std::string GetCustomExtensionName(CallInst *CI, HLSLExtensionsCodegenHelper &helper) {
    unsigned opcode = GetHLOpcode(CI);
    std::string name = helper.GetIntrinsicName(opcode);
    ReplaceOverloadMarkerWithTypeName(name, CI);

    return name;
  }

  static std::string GetDefaultCustomExtensionName(CallInst *CI, StringRef strategyName) {
    return (Twine(CI->getCalledFunction()->getName()) + "." + Twine(strategyName)).str();
  }

  static bool HasCustomExtensionName(const std::string name) {
    return name.size() > 0;
  }

  typedef unsigned OverloadArgIndex;
  static constexpr OverloadArgIndex DefaultOverloadIndex = std::numeric_limits<OverloadArgIndex>::max();

  // Choose the (return value or argument) type that determines the overload type
  // for the intrinsic call.
  // If the overload arg index was explicitly specified (see ParseOverloadArgIndex)
  // then we use that arg to pick the overload name. Otherwise we pick a default
  // where we take the return type as the overload. If the return is void we
  // take the first (non-opcode) argument as the overload type.
  static Type *SelectOverloadSlot(CallInst *CI, OverloadArgIndex ArgIndex) {
   if (ArgIndex != DefaultOverloadIndex)
    {
      return CI->getArgOperand(ArgIndex)->getType();
    }

    Type *ty = CI->getType();
    if (ty->isVoidTy()) {
      if (CI->getNumArgOperands() > 1)
        ty = CI->getArgOperand(1)->getType(); // First non-opcode argument.
    }

    return ty;
  }

  static Type *GetOverloadType(CallInst *CI, OverloadArgIndex ArgIndex) {
    Type *ty = SelectOverloadSlot(CI, ArgIndex);
    if (ty->isVectorTy())
      ty = ty->getVectorElementType();

    return ty;
  }

  static std::string GetTypeName(Type *ty) {
      std::string typeName;
      llvm::raw_string_ostream os(typeName);
      ty->print(os);
      os.flush();
      return typeName;
  }

  static std::string GetOverloadTypeName(CallInst *CI, OverloadArgIndex ArgIndex) {
    Type *ty = GetOverloadType(CI, ArgIndex);
    return GetTypeName(ty);
  }

  // Parse the arg index out of the overload marker (if any).
  //
  // The function names use a $o to indicate that the function is overloaded
  // and we should replace $o with the overload type. The extension name can
  // explicitly set which arg to use for the overload type by adding a colon
  // and a number after the $o (e.g. $o:3 would say the overload type is
  // determined by parameter 3).
  //
  // If we find an arg index after the overload marker we update the size
  // of the marker to include the full parsed string size so that it can
  // be replaced with the selected overload type.
  //
  static OverloadArgIndex ParseOverloadArgIndex(
      const std::string& functionName,
      size_t OverloadMarkerStartIndex,
      size_t *pOverloadMarkerSize)
  {
      assert(OverloadMarkerStartIndex != std::string::npos);
      size_t StartIndex = OverloadMarkerStartIndex + *pOverloadMarkerSize;

      // Check if we have anything after the overload marker to parse.
      if (StartIndex >= functionName.size())
      {
          return DefaultOverloadIndex;
      }

      // Does it start with a ':' ?
      if (functionName[StartIndex] != ':')
      {
          return DefaultOverloadIndex;
      }

      // Skip past the :
      ++StartIndex;

      // Collect all the digits.
      std::string Digits;
      Digits.reserve(functionName.size() - StartIndex);
      for (size_t i = StartIndex; i < functionName.size(); ++i)
      {
          char c = functionName[i];
          if (!isdigit(c))
          {
              break;
          }
          Digits.push_back(c);
      }

      if (Digits.empty())
      {
          return DefaultOverloadIndex;
      }

      *pOverloadMarkerSize = *pOverloadMarkerSize + std::strlen(":") + Digits.size();
      return std::stoi(Digits);
  }

  // Find the occurence of the overload marker $o and replace it the the overload type name.
  static void ReplaceOverloadMarkerWithTypeName(std::string &functionName, CallInst *CI) {
    const char *OverloadMarker = "$o";
    size_t OverloadMarkerLength = 2;

    size_t pos = functionName.find(OverloadMarker);
    if (pos != std::string::npos) {
      OverloadArgIndex ArgIndex = ParseOverloadArgIndex(functionName, pos, &OverloadMarkerLength);
      std::string typeName = GetOverloadTypeName(CI, ArgIndex);
      functionName.replace(pos, OverloadMarkerLength, typeName);
    }
  }
};

std::string ExtensionLowering::GetExtensionName(llvm::CallInst *CI) {
  // Delegate to ExtensionName: a helper-provided name wins; otherwise the name
  // is derived from the called function and the current strategy.
  return ExtensionName(CI, m_strategy, m_helper).Get();
}
